import logging
from unittest.mock import patch

import torch
import torch.nn as nn
from torch.distributed.fsdp._fully_shard._fsdp_common import compiled_autograd_enabled, TrainingState
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam, ShardedState
from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup, AllGatherState
from torch.distributed.utils import _to_kwargs
from torch_npu.distributed.fsdp._add_fsdp_patch import _patched_finalize_backward
import torch_npu.distributed.fsdp._add_fsdp_patch as add_fsdp_patch
from torch_npu.testing.testcase import TestCase, run_tests


class TestAddFsdpPatch(TestCase):
    """Tests for the FSDP patch helpers imported from ``_add_fsdp_patch``.

    Only ``_patched_finalize_backward`` is actually invoked. The two
    ``test_get_param_all_gather_inputs_*`` cases build their fixtures under a
    patched ``compiled_autograd_enabled`` flag but do not call or assert on
    any patched function yet (see the TODO notes in each test).
    """

    # Dotted path handed to ``unittest.mock.patch`` by the all-gather tests.
    # NOTE(review): this swaps the attribute on ``_fsdp_common`` only; a module
    # that bound the name with ``from ... import compiled_autograd_enabled``
    # keeps the original function -- confirm the code under test resolves the
    # flag through the patched module once these tests call it.
    _COMPILED_AUTOGRAD_FLAG = (
        'torch.distributed.fsdp._fully_shard._fsdp_common.compiled_autograd_enabled'
    )

    class _ShardedParamStub:
        """Attribute-only stand-in for a parameter object in ``SHARDED`` state.

        Args:
            offload_to_cpu: Value stored on ``offload_to_cpu``.
            all_gather_inputs: List stored on ``all_gather_inputs``.
        """

        def __init__(self, *, offload_to_cpu, all_gather_inputs):
            self.all_gather_inputs = all_gather_inputs
            self.param_dtype = torch.float32
            self.offload_to_cpu = offload_to_cpu
            self._sharded_local_tensor = torch.tensor([1.0, 2.0])
            self.sharded_state = ShardedState.SHARDED
            self._sharded_param_data = torch.tensor([3.0, 4.0])
            self._sharded_post_forward_param_data = torch.tensor([5.0, 6.0])
            self.device = torch.device("cpu")

    def test_get_param_all_gather_inputs_compiled_autograd(self):
        """Build a non-offloaded stub while the compiled-autograd flag is patched to True.

        TODO(review): nothing from ``add_fsdp_patch`` is called or asserted
        here; ``fsdp_params`` is prepared but never consumed. Wire in the
        function under test and assert on its result.
        """
        with patch(self._COMPILED_AUTOGRAD_FLAG, return_value=True):
            fsdp_params = [
                self._ShardedParamStub(
                    offload_to_cpu=False,
                    all_gather_inputs=[torch.tensor([3.0, 4.0])],
                )
            ]

    def test_patched_finalize_backward_with_events(self):
        """Check the state ``_patched_finalize_backward`` leaves on a stub group.

        The group is a duck-typed stub: its all-gather result carries an event
        and a work handle, and its single parameter holds a grad-offload
        event. After the call, ``_all_gather_result`` must be ``None`` and
        ``_post_forward_indices`` must be empty.
        """

        # Stubs are declared leaf-first so each name exists before it is used.
        class MockEvent:
            def synchronize(self):
                pass

            def wait(self, *args):
                pass

        class MockWork:
            def wait(self):
                pass

        class MockAllGatherResult:
            def __init__(self):
                self.all_gather_event = MockEvent()
                self.all_gather_work = MockWork()

        class MockFSDPParam:
            def __init__(self):
                self.grad_offload_event = MockEvent()

        class MockFSDPParamGroup:
            def __init__(self):
                self.fsdp_params = [MockFSDPParam()]
                self._all_gather_result = MockAllGatherResult()
                self._post_forward_indices = [1, 2, 3]

            def _wait_for_post_backward(self):
                pass

        mock_group = MockFSDPParamGroup()

        _patched_finalize_backward(mock_group)

        self.assertIsNone(mock_group._all_gather_result)
        self.assertEqual(len(mock_group._post_forward_indices), 0)

    def test_get_param_all_gather_inputs_no_foreach_copy(self):
        """Build a CPU-offloaded stub while the compiled-autograd flag is patched to False.

        TODO(review): nothing from ``add_fsdp_patch`` is called or asserted
        here; ``fsdp_params`` is prepared but never consumed. Wire in the
        function under test and assert on its result.
        """
        with patch(self._COMPILED_AUTOGRAD_FLAG, return_value=False):
            fsdp_params = [
                self._ShardedParamStub(
                    offload_to_cpu=True,
                    all_gather_inputs=[torch.tensor([7.0, 8.0])],
                )
            ]


# Script entry point: when this file is executed directly, hand control to the
# run_tests helper imported from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()